GraphTranslator: Aligning Graph Model to Large Language Model for Open-ended Tasks

本文最后更新于 2024年8月5日 下午

GraphTranslator: Aligning Graph Model to Large Language Model for Open-ended Tasks

论文要做什么

通过LLM(大语言模型)与GM(图模型)相结合,实现一个既能解决预定义任务又能解决开放式任务的模型。

实验结果表明,该方法在零样本分类等任务中表现出色,并具有广泛的应用前景。

论文背景

现有将LLM应用到GM的工作主要有两种

  • 将LLM作为Enhancer(增强器)
    • 弊端:不能处理开放式任务
  • 将LLM作为Predictor(预测器)
    • 弊端:容易出现幻觉答案

作者通过GraphTranslator模型对齐GM和LLM,来扩展GM处理开放式任务的能力。

论文结果

通过现实数据集评估了GraphTranslator,结果表明了GraphTranslator在零样本分类任务和图问答任务上的有效性。

论文方法

model

上图左半部分可以看到论文的方法包括四个模块:冻结的GM,冻结的LLM,ProducerTranslator

【Whalepaper第100期】NLP论文研读:GraphTranslator-结合预训练的图模型与大型语言模型来处理预定义和开放式任务【精准空降到 12:46】

1. 冻结的GM

  • 冻结的GM主要作用:为数据集的所有节点生成嵌入向量\(z_v\)
    • \(z_v\)​会传给Producer和Translator的交叉注意力层(Cross Attention)
  • 给定一个图\(\mathcal G = (\mathcal V, A, \{s_v\}_{v \in \mathcal V}), A\in \{0,1\}^{N\times N}\)
    • \(\mathcal V\)是所有节点,\(A\)是图的邻接矩阵,\(s_v\)是节点\(v\)的文本描述
  • 典型的图神经网络表示为\(\mathcal g_\theta(A,X)\)
    • \(\theta\)是可学习的参数,\(X\)是通过词袋(BoW)处理\(\left\{s_v\right\}\)得到的embedding,\(\mathcal g\)使用的是GraphSAGE
  • GraphSAGE在目标节点\(v\)周围采样固定大小(2跳)的邻居\(\mathcal N(v)\),构成节点\(v\)的子图,然后将节点对上一层嵌入\(h_v^{k-1}\)与聚合的领域向量\(\left\{h_u^{k-1},\forall u\in \mathcal N(v)\right\}\)连接起来。(将邻居的当前embedding聚合起来,然后与目标节点的embedding拼接起来,逐层传播聚合,最后可以得到\(z_v\)
    • \(h_v^k = \sigma\left(W^k \cdot CONCAT\left(h_v^{k-1} \cup AGGREGATE_k\left\{h_u^{k-1},\forall u \in \mathcal N(v)\right\}\right)\right)\)​​
      • \(W^k\)\(\mathcal g\)的一个参数,表示第\(k\)层的权重矩阵
    • \(z_v = \mathcal g_{\theta^*}\left(A, X\right)_v\)
      • 这里的\(\theta ^*\)表示参数是固定的,所以这个模块是冻结的GM。(作者说GraphSAGE是阿里已上线的参数,没有改过)

2. Producer

  • 作用利用LLM生成节点嵌入和文本描述的匹配数据,并文本化节点信息。对齐GM和LLM

    • 生成一个节点文本对
    • node_id embedding paper_summary citepapers_summary title
      42 -0.077210054, 0.26279667, 0.82795596, ... This paper studies ... These papers cover ... contact representations of sparse planar graphs
  • 构建对齐数据\(P=\left\{(z_v,t_v)\right\}_{i=1}^{\mathcal N_p}\)

    • \(t_v = \left\{t_v^s,t_v^{\mathcal N(v)},t_v^c\right\}\),从左到右分别是:自己的属性信息,邻居的属性信息,节点间的共性
      • 实际代码中发现只有前两个,第三个在代码和附录的prompt中均未找到
    • \(t_v\)是使用"Chain of Thought"(CoT)引导GPT生成的高质量描述
      • image-20240709110127287
      • \(t_v^s\): 通常节点属性(文本或数字数据)被视为每个节点的特征,使用词袋模型实现。Producer使用LLM总结并分析训练集中每个节点\(v\)的属性,得到节点描述记为\(t_v^s\)
      • \(t_v^{\mathcal N(v)}\): GraphSAGE随机抽样邻居节点\(\mathcal N(v)\)的子集并聚合他们的表示,得到neighbor embedding。节点和邻居信息通过加权求和或进一步融合。Producer使用LLM总结\(\mathcal N(v)\)的属性,得到邻居信息的描述记为\(t_v^{\mathcal N(v)}\)
    • \(z_v\)也就是上表中的embedding,实现代码位于producer.py的133行
  • 具体的代码位于producer.py的LLM.inference_chatglm_arxiv

3. Translator

  • 作用:将节点嵌入转化为token,实现GM和LLM对齐

  • 执行流程:

    • 从上边流程图可知,输入包括

      • Query Token \(Q\)

        • \(Q\)是需要学习的参数,所以开始时赋值为0

          1
          2
          3
          4
          5
          6
          7
          8
          9
          10
          # Translator/models/translator_models/translator.py init_Qformer
          query_tokens = nn.Parameter(
          torch.zeros(1, num_query_token, encoder_config.hidden_size)
          )
          query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

          # Translator/models/translator_models/translator_qformer_arxiv.py TranslatorQformerArxiv.forward
          # 这里的behavior_embeds是z_v
          # 将Q的形状扩展为z_v的形状
          query_tokens = self.query_tokens.expand(behavior_embeds.shape[0], -1, -1)
        • 在第一次训练阶段的第一步就是训练\(Q\)进行学习

      • Description Tokens \(t_v\)​​

        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        # Translator/models/translator_models/translator_qformer_arxiv.py
        def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("../models/bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

        # Translator/models/translator_models/translator_qformer_arxiv.py TranslatorQformerArxiv.__init__
        self.tokenizer = self.init_tokenizer()

        # Translator/models/translator_models/translator_qformer_arxiv.py TranslatorQformerArxiv.forward
        text_tokens = self.tokenizer(
        text, # 这里的text是上边表格中一条数据的`paper_summary`
        padding="max_length",
        truncation=True,
        max_length=self.max_txt_len,
        return_tensors="pt",
        ).to(behavior_embeds.device)
      • Node Embedding \(z_v\)​​

        • 在Producer模块存到了数据中,对应上表的embedding

          1
          behavior_embeds = torch.unsqueeze(samples[1], dim=1)
    • Translator中包括两个编码器\(f_t(\cdot)\)\(f_z(\cdot)\)

      • \(f_t(\cdot)\)基于BERT实现,用于提取语言特征\(T_v = f_t(t_v)\)
        • \(f_t(\cdot)\)​包含12层Transformer块
      • \(f_z(\cdot)\)基于Transformer网络
        • 以M个可学习的token embeddings作为输入,称为query token \(Q=\{q_i\}_{i=1}^M\)
        • 输出M个特征\(H_v=f_z(Q,z_v)=\{h_{v,i}\}_{i=1}^M\),提取\(z_v\)中与\(t_v\)最相关的信息
        • 使用自注意力层相互交互,通过交叉注意力层(Cross Attention)与节点嵌入\(z_v\)交互,并通过在\(f_t\)\(f_z\)之间共享的自注意力层(Shared Self-Attention)与\(t_v\)通信
    • 损失值将由\(T_v\)\(H_v\)计算得出,具体在下面训练部分说明

4.Train

训练分为了两个阶段

  • Train-1: 对齐GM和Text
  • Train-2: 对齐GM和LLM

Train1

对齐GM和Text,也就是对齐\(H_v=\{h_{v, i}\}_{i=1}^M\)\(\tilde t_v\)\(\tilde t _v\)\(T_v\)[CLS]token嵌入)

对应的代码为

根据作者的注释可以将forward函数分为4个部分(第一部分姑且命名为Text Feature Extractor)

Text Feature Extractor

这部分为论文中提到的\(f_t\), 使用Qformer.bert处理text_tokens(\(t_v\))得到text_feat(\(T_v\)).

  • behavior_feats表示\(H_v = f_z(Q, z_v) = \left\{h_{v, i}\right\}_{i=1}^M\), M指可学习的token embedding的数量即num_query_tokens

    • 对于节点嵌入\(z_v\) ,我们还采用基于 Transformer 的网络\(f_z(\cdot)\) ,以\(M\)个可学习的标记嵌入作为输入(称为查询标记\(Q=\left\{q_i\right\}_{i=1}^M\),输出\(𝑀\)特征\(H_v=\left\{h_{v,i}\right\}_{𝑖=1}^𝑀\)\(H_v=f_z(Q,z_v)\) ,提取与\(𝑡_𝑣\)最相关的\(z_𝑣\)​信息。

  • self.tokenizer在初始化中使用tokenizer.add_special_tokens({"bos_token": "[DEC]"})实现了将[CLS]标签替换为[DEC]

    • BERT模型的默认输入格式为[CLS] xxxxx [SEP],文本数据的开头通常会有[CLS]标签
  • 使用bert提取\(t_v\)的特征得到\(T_v\)

    • 对于文本描述\(t_v\) ,我们利用文本编码器\(f_t(\cdot)\)(例如 BERT [4] )提取语言特征\(T_v = f_t(t_v)\),其中\(f_t(\cdot)\)包含 12 层 Transformer 块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
behavior_embeds = torch.unsqueeze(samples[1], dim=1)    # z_v
text = samples[2]
behavior_embeds = behavior_embeds.to(self.device)
behavior_atts = torch.ones(behavior_embeds.size()[:-1], dtype=torch.long).to(behavior_embeds.device)

query_tokens = self.query_tokens.expand(behavior_embeds.shape[0], -1, -1) # Qurery Tokens Q

query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=behavior_embeds,
encoder_attention_mask=behavior_atts,
use_cache=True,
return_dict=True,
)

behavior_feats = F.normalize( # H_v
self.behavior_proj(query_output.last_hidden_state), dim=-1
)

text_tokens = self.tokenizer( # t_v
text,
padding="max_length",
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(behavior_embeds.device)

text_output = self.Qformer.bert(
text_tokens.input_ids,
attention_mask=text_tokens.attention_mask,
return_dict=True,
)
text_feat = F.normalize( # T_v
self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
)
Image-text Contrastive

这里感觉应该是Graph-text Contrastive,不知道是不是作者打错了

  • 通过计算 \(h_{v,i}\)\(T_v\)的最相似的索引 与 \(\tilde t_v\)\(H_v\)的最相似的索引 的交叉熵损失的均值loss_itc

    • \(\tilde t_v\)\(T_v\)[CLS] token embedding

    • 对比目标通过最大化它们的相互信息来对齐\(H_v\)\(\tilde t_v\)。我们首先计算\(\tilde t_v\)\(H_v\)中每个token之间的成对相似度,并选择最高的一个作为相似度得分,然后将正对的相似度与负对的相似度进行对比。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###============== Image-text Contrastive ===================###
behavior_feats_all = concat_all_gather(
behavior_feats
) # [batch_size*num_gpu, num_query_tokens, embed_dim] # torch.Size([8, 32, 256])
text_feat_all = concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim]

sim_q2t = torch.matmul(
behavior_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
).squeeze()
# [batch_size, batch_size*num_gpu, num_query_tokens]

# image-text similarity: aggregate across all query tokens
sim_i2t, _ = sim_q2t.max(-1)
sim_i2t = sim_i2t / self.temp

# text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
sim_t2q = torch.matmul(
text_feat.unsqueeze(1).unsqueeze(1), behavior_feats_all.permute(0, 2, 1)
).squeeze()

# text-image similarity: aggregate across all query tokens
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu]

rank = 0
bs = behavior_embeds.size(0)
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
behavior_embeds.device
)

loss_itc = (
F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
Image-text Matching
  • 这里是计算正对相似度和负对相似度进行对比

    • 分别计算了text的负样本graph(不是这个text对应的图节点) 和 计算了graph的负样本text(不是这个graph对应的text)

    • 使用每个graph的负样本text的id(text_ids_all)作为BERT输入,

    • 对比目标通过最大化它们的相互信息来对齐\(H_v\)\(\tilde t_v\)。我们首先计算\(\tilde t_v\)\(H_v\)中每个token之间的成对相似度,并选择最高的一个作为相似度得分,然后将正对的相似度与负对的相似度进行对比。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
###============== Image-text Matching ===================###
text_input_ids_world = concat_all_gather(text_tokens.input_ids)
text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
behavior_embeds_world = all_gather_with_grad(behavior_embeds)
with torch.no_grad():
weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
weights_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(0)
weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
weights_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(0)

# select a negative image for each text
behavior_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
behavior_embeds_neg.append(behavior_embeds_world[neg_idx])
behavior_embeds_neg = torch.stack(behavior_embeds_neg, dim=0)

# select a negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(text_input_ids_world[neg_idx])
text_atts_neg.append(text_attention_mask_world[neg_idx])

text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

text_ids_all = torch.cat(
[text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
) # pos, pos, neg
text_atts_all = torch.cat(
[text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
dim=0,
)

query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
behavior_embeds.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

behavior_embeds_all = torch.cat(
[behavior_embeds, behavior_embeds_neg, behavior_embeds], dim=0
) # pos, neg, pos
behavior_atts_all = torch.ones(behavior_embeds_all.size()[:-1], dtype=torch.long).to(
behavior_embeds.device
)

output_itm = self.Qformer.bert(
text_ids_all,
query_embeds=query_tokens_itm,
attention_mask=attention_mask_all,
encoder_hidden_states=behavior_embeds_all,
encoder_attention_mask=behavior_atts_all,
return_dict=True,
)

vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
vl_output = self.itm_head(vl_embeddings) # 二元分类器
logits = vl_output.mean(dim=1) # 在序列长度上取均值

itm_labels = torch.cat(
[torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
dim=0,
).to(behavior_embeds.device)
loss_itm = F.cross_entropy(logits, itm_labels)
Image Captioning

标题的意思应该是基于给定的graph生成描述文本,Q-Former生成文本后会得到一个损失值

论文中没有找到对应的出处,GPT回答是“指导模型在训练过程中生成更高质量的文本”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
##================= Image Captioning ========================##
decoder_input_ids = text_tokens.input_ids.clone()
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id # 将第一个位置设置为开始标记[BOS]
labels = decoder_input_ids.masked_fill(
decoder_input_ids == self.tokenizer.pad_token_id, -100
) # 将填充标记[PAD]替换为-100,这样计算损失时可以被忽略

query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
behavior_embeds.device
)
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
lm_output = self.Qformer( # 传给Q-Former用于生成文本
decoder_input_ids,
attention_mask=attention_mask,
past_key_values=query_output.past_key_values,
return_dict=True,
labels=labels,
)

loss_lm = lm_output.loss
Final
  • 模型的损失使用前面三个阶段的损失和
1
2
3
4
5
6
return TranslatorOutput(
loss=loss_itc + loss_itm + loss_lm,
loss_itc=loss_itc,
loss_itm=loss_itm,
loss_lm=loss_lm,
)

Train2

这一阶段的主要目标是训练Translator实现GM-LLM对齐

实现的代码主要位于Translator/models/translator_models/translator_chatglm_arxiv.py

  • 训练过程中重新使用Producer的结果作为数据集,没有直接存储stage1的\(H_v\)使用

  • 这里的\(H_v\)(vtokens)被一个线性层self.chatglm2_proj映射为LLM相同的维度上

  • prepare_lm_input的主要返回三个参数input_ids,labels,inputs_embeds

    • input_ids = a_ids + b_ids
      • a_ids = [IMAGE_TOKEN_ID] * nvtoken + tokenizer.encode(text, add_special_tokens=False),IMAGE_TOKEN_ID是固定参数101, nvtoken\(H_v\)​的特征数,text是论文流程图中的Instruction
      • b_ids = tokenizer.encode(ans, add_special_tokens=False),ans\(t_v\)中的一段
      • input_ids = [IMAGE_TOKEN_ID]*nvtoken + 'Question: Please summarize the topic and content of the paper and its citations in English. Answer:' + 'This paper studies ...'文本在程序中均为embedding
    • label = input_ids.detach().clone(), label[:context_length]=-100
      • label会复制input_ids, 然后将a_ids的位置改为-100, 用于计算损失时忽略
    • inputs_embeds=self.chatglm2_model.transformer.embedding.word_embeddings(input_ids)
      • input_ids通过LLM的嵌入层转换为嵌入向量
      • vtoken(\(H_v\))插入到相应位置的嵌入向量中,inputs_embeds[:, nvtoken_id: nvtoken_id + nvtoken] = vtokens
      • 最后将嵌入向量的形状调整为适合LLM输入的格式,inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
multimodal_embeds = samples[1].unsqueeze(dim=1).to(self.device)         # z_v
text = samples[2] # t_v
instruction = ['\nQuestion: Please summarize the topic and content of the paper and its citations in English. Answer:' for _ in range(len(text))]
device = self.Qformer.bert.device

multimodal_atts = torch.ones(multimodal_embeds.size()[:-1], dtype=torch.long).to(device)

query_tokens = self.query_tokens.expand(multimodal_embeds.shape[0], -1, -1).to(device)

query_output = self.Qformer.bert( # H_v
query_embeds=query_tokens,
encoder_hidden_states=multimodal_embeds,
encoder_attention_mask=multimodal_atts,
return_dict=True,
)
vtokens = self.chatglm2_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :]) # H_v通过线性层投影到LLM相同的维度上

input_ids, labels, inputs_embeds = self.prepare_lm_input(
vtokens=vtokens, text_input=instruction, answer=text
)

outputs = self.chatglm2_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
return_dict=True,
labels=labels,
)

loss = outputs.loss

return {"loss": loss, "vtokens": vtokens, "logits": outputs.logits}

GraphTranslator: Aligning Graph Model to Large Language Model for Open-ended Tasks
https://tippye.github.io/2024/07/19/GraphTranslator: Aligning Graph Model to Large Language Model for Open-ended Tasks/
作者
Tippy
发布于
2024年7月19日
许可协议